import copy
import torch
from epr_mappo.model.cnn import CNNBase
from epr_mappo.model.mlp import MLPBase
from epr_mappo.model.rnn import RNNLayer
from epr_mappo.model.act import ACTLayer
from epr_mappo.model.actor import Actor
from epr_mappo.util.util import get_shape_from_obs_space, get_shape_from_act_space


class AdvActor(Actor):
    """Actor network used by the adversary.

    After running the parent ``Actor`` constructor, this class assigns its own
    ``self.base``, (optionally) ``self.rnn`` and ``self.act`` modules. When
    ``args["super_adversary"]`` is truthy, the first entry of the observation
    shape is enlarged by
    ``(num_agents - 1) * get_shape_from_act_space(action_space)`` before the
    base network is built.

    NOTE(review): ``self.hidden_sizes``, ``self.recurrent_N``,
    ``self.initialization_method``, ``self.gain`` and the recurrent-policy
    flags are read here but not set here; presumably ``Actor.__init__``
    defines them -- confirm against the parent class.
    """

    def __init__(self, args, obs_space, action_space, num_agents, device=torch.device("cpu")):
        """Build the adversary's network modules.

        Args:
            args: Mapping of configuration values; ``args["super_adversary"]``
                is read here, and the whole mapping is forwarded to the parent
                constructor, the base network and the action layer.
            obs_space: Observation space, passed to
                ``get_shape_from_obs_space`` to obtain the observation shape.
            action_space: Action space, passed to ``ACTLayer`` and (for a
                super adversary) to ``get_shape_from_act_space``.
            num_agents: Number of agents; only used to size the extra
                observation entries of a super adversary.
            device: Device the module is moved to at the end of construction.
        """
        super(AdvActor, self).__init__(args, obs_space, action_space, device)
        self.super_adversary = args["super_adversary"]  # whether the adversary has defenders' policies
        obs_shape = copy.deepcopy(get_shape_from_obs_space(obs_space))

        if self.super_adversary:
            # Item assignment below needs a mutable sequence. The helper may
            # hand back an immutable tuple / torch.Size (assumption: e.g. a
            # Box space's ``.shape``), which deepcopy keeps immutable, so
            # convert only in that case to leave list inputs untouched.
            if isinstance(obs_shape, tuple):
                obs_shape = list(obs_shape)
            # Make room for the other (num_agents - 1) agents' action entries.
            # NOTE(review): for 3-element shapes this grows the first axis of
            # the input handed to CNNBase -- confirm that is intended.
            obs_shape[0] = obs_shape[0] + (num_agents - 1) * get_shape_from_act_space(action_space)

        # A 3-element shape selects the CNN feature extractor; anything else
        # falls back to the MLP.
        base = CNNBase if len(obs_shape) == 3 else MLPBase
        self.base = base(args, obs_shape)

        if self.use_naive_recurrent_policy or self.use_recurrent_policy:
            self.rnn = RNNLayer(self.hidden_sizes[-1], self.hidden_sizes[-1],
                                self.recurrent_N, self.initialization_method)

        self.act = ACTLayer(action_space, self.hidden_sizes[-1],
                            self.initialization_method, self.gain, args)

        # Move the modules created above onto the requested device.
        self.to(device)
